using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using NetCorePal.Extensions.CodeAnalysis.SourceGenerators;

namespace NetCorePal.Extensions.CodeAnalysis.SourceGenerators;

/// <summary>
/// Incremental source generator that emits one assembly-level <c>CommandHandlerMetadataAttribute</c>
/// per command handler in the compilation. Each attribute records the handler type, the command type
/// it handles, and the aggregate-root types that the handler's method bodies instantiate or invoke
/// methods on.
/// </summary>
[Generator]
public class CommandHandlerMetadataGenerator : IIncrementalGenerator
{
    /// <inheritdoc />
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        // Every class/record declaration is a candidate; whether it really is a command handler
        // needs semantic information, which is only resolved in the output step below.
        var typeDeclarations = context.SyntaxProvider
            .CreateSyntaxProvider(
                predicate: static (node, _) => node is ClassDeclarationSyntax or RecordDeclarationSyntax,
                transform: static (ctx, _) => ctx.Node);

        // The output step depends on the whole Compilation, so it re-runs whenever the compilation changes.
        var compilationAndTypes = context.CompilationProvider.Combine(typeDeclarations.Collect());

        context.RegisterSourceOutput(compilationAndTypes, static (spc, source) =>
        {
            var (compilation, typeNodes) = source;
            var handlerMetas = CollectHandlerMetadata(compilation, typeNodes, spc.CancellationToken);
            if (handlerMetas.Count > 0)
            {
                spc.AddSource("CommandHandlerMetadata.g.cs", RenderSource(handlerMetas));
            }
        });
    }

    /// <summary>
    /// Resolves each candidate declaration to its type symbol and gathers metadata for the ones that are
    /// command handlers with a resolvable command type.
    /// </summary>
    /// <param name="compilation">Compilation used to obtain semantic models.</param>
    /// <param name="typeNodes">Class/record declaration nodes collected by the syntax provider.</param>
    /// <param name="cancellationToken">Token from the source production context.</param>
    /// <returns>One entry per distinct handler type, in declaration-visit order.</returns>
    private static List<(string HandlerType, string CommandType, string[] AggregateTypes)> CollectHandlerMetadata(
        Compilation compilation,
        IEnumerable<SyntaxNode> typeNodes,
        System.Threading.CancellationToken cancellationToken)
    {
        var handlerMetas = new List<(string HandlerType, string CommandType, string[] AggregateTypes)>();
        // A partial type produces one declaration node per part; report each handler symbol only once.
        var seenHandlers = new HashSet<INamedTypeSymbol>(SymbolEqualityComparer.Default);
        // Semantic models are per syntax tree and partial handlers can span several trees, so cache them.
        var modelCache = new Dictionary<SyntaxTree, SemanticModel>();

        foreach (var typeDecl in typeNodes)
        {
            cancellationToken.ThrowIfCancellationRequested();

            var model = GetModel(compilation, modelCache, typeDecl.SyntaxTree);
            if (model.GetDeclaredSymbol(typeDecl, cancellationToken) is not INamedTypeSymbol symbol)
            {
                continue;
            }

            // Only command handlers, and only the first time we meet this symbol.
            if (!symbol.IsCommandHandler() || !seenHandlers.Add(symbol))
            {
                continue;
            }

            var commandTypeSymbol = symbol.GetCommandFromCommandHandler();
            if (commandTypeSymbol is null)
            {
                continue;
            }

            handlerMetas.Add((
                symbol.ToDisplayString(),
                commandTypeSymbol.ToDisplayString(),
                CollectAggregateTypes(symbol, compilation, modelCache, cancellationToken)));
        }

        return handlerMetas;
    }

    /// <summary>
    /// Finds aggregate-root types referenced in the bodies of the handler's methods, either through an
    /// object creation (<c>new T(...)</c> or target-typed <c>new(...)</c>) or through a method invocation
    /// whose containing type is an aggregate root.
    /// </summary>
    /// <param name="handler">The command handler type symbol.</param>
    /// <param name="compilation">Compilation used to obtain semantic models for other syntax trees.</param>
    /// <param name="modelCache">Per-tree semantic model cache shared across handlers.</param>
    /// <param name="cancellationToken">Token from the source production context.</param>
    /// <returns>Distinct display strings of the aggregate-root types found.</returns>
    private static string[] CollectAggregateTypes(
        INamedTypeSymbol handler,
        Compilation compilation,
        Dictionary<SyntaxTree, SemanticModel> modelCache,
        System.Threading.CancellationToken cancellationToken)
    {
        var aggregateTypes = new HashSet<string>();
        foreach (var method in handler.GetMembers().OfType<IMethodSymbol>())
        {
            foreach (var syntaxRef in method.DeclaringSyntaxReferences)
            {
                if (syntaxRef.GetSyntax(cancellationToken) is not MethodDeclarationSyntax methodSyntax)
                {
                    continue;
                }

                // Scan either the block body or the `=>` expression body; methods without either
                // (abstract/extern/partial definitions) contribute nothing.
                var bodyNodes = methodSyntax.Body?.DescendantNodes()
                                ?? methodSyntax.ExpressionBody?.DescendantNodes()
                                ?? Enumerable.Empty<SyntaxNode>();

                // The method may be declared in a different file than the declaration we started from
                // (partial type), so the semantic model must come from the method's own tree.
                var model = GetModel(compilation, modelCache, methodSyntax.SyntaxTree);

                foreach (var node in bodyNodes)
                {
                    var candidate = node switch
                    {
                        BaseObjectCreationExpressionSyntax creation =>
                            model.GetTypeInfo(creation, cancellationToken).Type as INamedTypeSymbol,
                        InvocationExpressionSyntax invocation =>
                            (model.GetSymbolInfo(invocation, cancellationToken).Symbol as IMethodSymbol)?.ContainingType,
                        _ => null,
                    };

                    if (candidate != null && candidate.IsAggregateRoot())
                    {
                        aggregateTypes.Add(candidate.ToDisplayString());
                    }
                }
            }
        }

        return aggregateTypes.ToArray();
    }

    /// <summary>Returns the cached semantic model for <paramref name="tree"/>, creating it on first use.</summary>
    private static SemanticModel GetModel(
        Compilation compilation,
        Dictionary<SyntaxTree, SemanticModel> cache,
        SyntaxTree tree)
    {
        if (!cache.TryGetValue(tree, out var model))
        {
            model = compilation.GetSemanticModel(tree);
            cache.Add(tree, model);
        }

        return model;
    }

    /// <summary>
    /// Renders the generated file: one assembly attribute per handler, with the aggregate type names
    /// appended as extra string arguments when any were found.
    /// </summary>
    private static string RenderSource(
        List<(string HandlerType, string CommandType, string[] AggregateTypes)> handlerMetas)
    {
        var sb = new StringBuilder();
        sb.AppendLine("// <auto-generated/>");
        sb.AppendLine("using System;");
        sb.AppendLine("using NetCorePal.Extensions.CodeAnalysis.Attributes;");
        foreach (var (handlerType, commandType, aggregateTypes) in handlerMetas)
        {
            if (aggregateTypes.Length == 0)
            {
                sb.AppendLine($"[assembly: CommandHandlerMetadataAttribute(\"{handlerType}\", \"{commandType}\" )]");
            }
            else
            {
                var aggs = string.Join(", ", aggregateTypes.Select(a => "\"" + a + "\""));
                sb.AppendLine($"[assembly: CommandHandlerMetadataAttribute(\"{handlerType}\", \"{commandType}\", {aggs})]");
            }
        }

        return sb.ToString();
    }
}
